#=
/opt/openmpi/bin/mpiexec
=#

# NOTE: calling OpenBLAS routines under OpenMPI frequently misbehaves — possibly
# because this OpenMPI build was compiled with ifort? Workaround: switch to
# Intel MPI, or use MKL instead.
# Under MPI, multithreaded BLAS badly interferes with performance; use
# Intel MPI + MKL with the SEQUENTIAL threading layer.
# The MKL import must come before everything else.
using MKL
MKL.set_threading_layer(MKL.THREADING_SEQUENTIAL)

include("../src/qmcmary.jl")
using ..qmcmary

using Random
using MPI


"""保存hs场的位型"""
"""
    savehs(hscfg, head, fname)

Append one Hubbard–Stratonovich field configuration to the file `fname`:
first the header line `head` (a `String`), then each time slice (row of
`hscfg`) on its own line in its `string` form, e.g. `"[1, -1, 1]"`.
The file is opened in append mode so successive sweeps accumulate.
"""
function savehs(hscfg, head, fname)
    open(fname, "a") do io
        println(io, head)
        for ti in Base.OneTo(size(hscfg, 1))
            println(io, string(hscfg[ti, :]))
        end
    end
end


"""
进行前向计算
"""
"""
    forward(sp, Nt, Nx, Ng, psi, the; hsf=nothing)

Run the forward DQMC pass for the per-site spin directions given by the
spherical angles `psi` (ψ) and `the` (θ).

The auxiliary (HS) field is taken from `hsf` when supplied; otherwise it
is drawn randomly as ±1 and one warm-up sweep is performed first. Each
measurement sweep's configuration is appended to a per-MPI-rank scratch
file `tmp<rank>_<size>` so `step_theta1` can re-read it afterwards.

Returns `(tph/nsp, hscfg)`: the phase (sign) averaged over the `nsp`
measurement sweeps, and the final field configuration.
"""
function forward(sp, Nt, Nx, Ng, psi, the; hsf=nothing)
    # Per-site direction unit vectors from the angles (ψ, θ).
    nx = sin.(psi).*sin.(the)
    ny = cos.(psi).*sin.(the)
    nz = cos.(the)
    comm = MPI.COMM_WORLD
    # Scratch file is unique per rank; opening it with "w" truncates any
    # leftover data from a previous iteration before savehs appends.
    fname = "tmp$(MPI.Comm_rank(comm))_$(MPI.Comm_size(comm))"
    fil = open(fname, "w")
    close(fil)
    hscfg::Matrix{Int} = Matrix{Int}(undef, Nt, Nx)
    nwarm = 0
    if isnothing(hsf)
        # Fresh start: every entry is ±1 with equal probability
        # (2*(round(U[0,1)) - 0.5) is exactly ±1.0, convertible to Int).
        for bi in Base.OneTo(Nt)
            cfg = 2(round.(rand(Nx)) .- 0.5)
            hscfg[bi, :] .= cfg
        end
        nwarm = 1
    else
        hscfg = hsf
    end
    sslen, bmats, allbmats, ss = initialize_SS(
        Nt, Ng, sp, hscfg, nx, ny, nz
    )
    # Warm-up sweeps (only when the field was freshly randomized).
    for _ in Base.OneTo(nwarm)
        println("warm")
        hscfg, bmats, allbmats, ss, gf, ph = dqmc_step(
            Nt, Ng, sp, hscfg, nx, ny, nz, sslen, bmats, allbmats, ss
        )
    end
    # Measurement sweeps: accumulate the phase and record every
    # configuration. The header written is the 1-based sweep index, so
    # the LAST header in the file equals nsp — step_theta1 relies on this.
    tph = 0.0
    nsp = 1
    for idx in Base.OneTo(nsp)
        hscfg, bmats, allbmats, ss, gf, ph = dqmc_step(
            Nt, Ng, sp, hscfg, nx, ny, nz, sslen, bmats, allbmats, ss
        )
        tph += ph
        savehs(hscfg, string(idx), fname)
    end
    return tph/nsp, hscfg
    #println("forward sgn: ", tph/nsp)
end


"""
导数
"""
"""
    step_theta1(sp, Nt, Nx, Ng, psi, the)

Compute the gradients of the loss with respect to the angles `the` (θ)
and `psi` (ψ), averaged over the configurations that `forward` stored in
this rank's scratch file.

Each stored configuration is read back, its stochastic series is
rebuilt, the gradients w.r.t. the direction vectors (nx, ny, nz) are
measured, and the chain rule maps them onto (θ, ψ).

Returns `(ttbar/nsp, tpbar/nsp)` where `nsp` is parsed from the last
header line that `savehs` wrote (the index of the final sweep).
"""
function step_theta1(sp, Nt, Nx, Ng, psi, the)
    comm = MPI.COMM_WORLD
    # Same per-rank scratch file that forward() wrote.
    fid = open("tmp$(MPI.Comm_rank(comm))_$(MPI.Comm_size(comm))", "r")
    #
    # Direction unit vectors; must match the parametrization in forward().
    nx = sin.(psi).*sin.(the)
    ny = cos.(psi).*sin.(the)
    nz = cos.(the)
    #println(ny)
    #println(nz)
    #
    # Gradient accumulators, phase sum, and the last header line seen.
    ttbar = zeros(Nx)
    tpbar = zeros(Nx)
    tph = 0.0
    bistr = "0"
    tapemap = pbpn_calc_tapemap(sp, nx, ny, nz)
    # Re-read the stored configurations: each block is a header line
    # (the sweep index) followed by Nt rows.
    while !eof(fid)
        bistr = readline(fid)
        #println(bistr)
        hscfg = Matrix{Int}(undef, Nt, Nx)
        for bi in Base.OneTo(Nt)
            cfg = readline(fid)
            # Row was written as string([a, b, ...]): strip the brackets
            # and split on commas; entries are "1" or "-1" (maybe padded).
            cfg = split(cfg[2:end-1], ',')
            for xi in Base.OneTo(Nx)
                hscfg[bi, xi] = strip(cfg[xi]) == "1" ? +1 : -1
            end
        end
        sslen, bmats, allbmats, ss = initialize_SS(
            Nt, Ng, sp, hscfg, nx, ny, nz
        )
        ### Equal-time Green's function and the phase of this configuration.
        gf, ph = eq_green_scratch(ss)
        #println(ph)
        tph += ph
        # Gradients w.r.t. the direction vectors for this configuration.
        nxbar, nybar, nzbar = meas_grad(
            ss, allbmats, sp, hscfg, nx, ny, nz; tapemap=tapemap
        )
        # Partial derivatives of (nx, ny, nz) w.r.t. θ and ψ.
        pnxpt = sin.(psi) .* cos.(the)
        pnypt = cos.(psi) .* cos.(the)
        pnzpt = -sin.(the)
        pnxpp = cos.(psi) .* sin.(the)
        pnypp = -sin.(psi) .* sin.(the)
        pnzpp = zeros(Nx)
        # Chain rule: ∂L/∂θ
        thetabar = @. nxbar * pnxpt + nybar * pnypt +  nzbar * pnzpt
        #println(thetabar)
        # Chain rule: ∂L/∂ψ
        psibar = @. nxbar * pnxpp + nybar * pnypp +  nzbar * pnzpp
        #println(psibar)
        ttbar += real(thetabar)
        tpbar += real(psibar)
    end
    close(fid)
    # The last header equals the number of measurement sweeps.
    nsp = parse(Int, bistr)
    #println(tph/nsp)
    return ttbar/nsp, tpbar/nsp
end



"""
    optimal(L)

Optimize the per-site spin angles (θ, ψ) on an `L`×`L` kagome lattice
(`Nx = 3L²` sites) with Adam, using gradients from the DQMC passes.

Each MPI rank runs its own Monte-Carlo chain; the sign and gradients are
averaged over ranks, the Adam update is performed on rank 0, and the new
angles are broadcast to all ranks so the chains stay in sync.
"""
function optimal(L)
    MPI.Init()
    Nx = 3*L^2                        # kagome lattice: 3 sites per unit cell
    hk = lattice_kagome(ComplexF64, L, -1.0+0.0im)
    Ui = 9*ones(Nx)                   # on-site interaction per site
    Nt = 60                           # imaginary-time slices
    Ng = 4
    # Uniform initial angles θ = ψ = π/4 on every site.
    psi = π*0.25*ones(Nx)
    the = π*0.25*ones(Nx)
    #
    comm = MPI.COMM_WORLD
    # Adam state (step counter, first moment, second moment) per angle set.
    t_psi = 0
    m_psi = zeros(Nx)
    v_psi = zeros(Nx)
    t_the = 0
    m_the = zeros(Nx)
    v_the = zeros(Nx)
    hsfig = nothing                   # HS field re-used between iterations
    #
    sp = default_splitting(Nt, hk, Ui)
    for st in Base.OneTo(100)
        sgn, hsfig = forward(sp, Nt, Nx, Ng, psi, the; hsf=hsfig)
        tbar, pbar = step_theta1(sp, Nt, Nx, Ng, psi, the)
        # Reduce blocks until every rank arrives; Allreduce would behave
        # the same here.
        println("sgn: $(sgn) $(MPI.Comm_rank(comm))")
        # Average sign over ranks. (Named sgnavg — the previous name `sum`
        # shadowed Base.sum.)
        sgnavg = MPI.Reduce(sgn, +, 0, comm)
        sgnavg = MPI.bcast(sgnavg, 0, comm)
        sgnavg = sgnavg / MPI.Comm_size(comm)
        # Sum gradients onto rank 0 (non-root ranks receive nothing).
        tbarsum = MPI.Reduce(tbar, +, 0, comm)
        pbarsum = MPI.Reduce(pbar, +, 0, comm)
        if MPI.Comm_rank(comm) == 0
            # Rank 0 averages the gradients and performs the Adam step.
            tbarsum = tbarsum / MPI.Comm_size(comm)
            pbarsum = pbarsum / MPI.Comm_size(comm)
            println("forward sign: $(sgnavg)")
            lg_the, m_the, v_the, t_the = next(Adam, tbarsum, m_the, v_the, t_the; α=0.02)
            lg_psi, m_psi, v_psi, t_psi = next(Adam, pbarsum, m_psi, v_psi, t_psi; α=0.02)
            the = the .+ lg_the
            psi = psi .+ lg_psi
            println("$(MPI.Comm_rank(comm)) the: $(the) psi: $(psi)")
        end
        # Broadcast the updated angles so every rank uses the same values.
        the = MPI.bcast(the, 0, comm)
        psi = MPI.bcast(psi, 0, comm)
        println("$(MPI.Comm_rank(comm)) the: $(the) psi: $(psi)")
    end
end

# Entry point: run the optimization on an L = 6 kagome lattice (Nx = 3*6^2 = 108 sites).
optimal(6)
